{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert a PyTorch model to TensorFlow using ONNX\n",
    "\n",
    "In this tutorial, we export a model defined in PyTorch to ONNX, then import the ONNX model into TensorFlow and run it. We also show how to save the resulting TensorFlow model to a file for later use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installations\n",
    "\n",
    "First, install [ONNX](https://github.com/onnx/onnx), [PyTorch](https://github.com/pytorch/pytorch), and [TensorFlow](https://github.com/tensorflow/tensorflow) by following the instructions in each project's repository.\n",
    "\n",
    "Then install torchvision with the following command:\n",
    "```\n",
    "pip install torchvision\n",
    "```\n",
    "\n",
    "Next, install [onnx-tensorflow](https://github.com/onnx/onnx-tensorflow) with the following commands:\n",
    "```\n",
    "git clone git@github.com:onnx/onnx-tensorflow.git && cd onnx-tensorflow\n",
    "pip install -e .\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define model\n",
    "\n",
    "In this tutorial, we use the [MNIST model](https://github.com/pytorch/examples/tree/master/mnist) from the PyTorch examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        self.fc1 = nn.Linear(320, 50)\n",
    "        self.fc2 = nn.Linear(50, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        x = x.view(-1, 320)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and test model\n",
    "Now let's train this model. By default, training runs on the GPU if one is available in your environment, and on the CPU otherwise. In this tutorial we train the model for 60 epochs, which takes about 15 minutes on a machine with one GPU. You can adjust the number of epochs depending on how well you want the model to be trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "def train(args, model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "def test(args, model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))\n",
    "\n",
    "# Training settings\n",
    "parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n",
    "parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n",
    "                    help='input batch size for training (default: 64)')\n",
    "parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n",
    "                    help='input batch size for testing (default: 1000)')\n",
    "parser.add_argument('--epochs', type=int, default=10, metavar='N',\n",
    "                    help='number of epochs to train (default: 10)')\n",
    "parser.add_argument('--lr', type=float, default=0.01, metavar='LR',\n",
    "                    help='learning rate (default: 0.01)')\n",
    "parser.add_argument('--momentum', type=float, default=0.5, metavar='M',\n",
    "                    help='SGD momentum (default: 0.5)')\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='disables CUDA training')\n",
    "parser.add_argument('--seed', type=int, default=1, metavar='S',\n",
    "                    help='random seed (default: 1)')\n",
    "parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n",
    "                    help='how many batches to wait before logging training status')\n",
    "\n",
    "# Train the model for 60 epochs and log the training status every 300 batches\n",
    "args = parser.parse_args(['--epochs', '60', '--log-interval', '300'])\n",
    "\n",
    "use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
    "\n",
    "torch.manual_seed(args.seed)\n",
    "\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                    transform=transforms.Compose([\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize((0.1307,), (0.3081,))\n",
    "                    ])),\n",
    "    batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize((0.1307,), (0.3081,))\n",
    "                    ])),\n",
    "    batch_size=args.test_batch_size, shuffle=True, **kwargs)\n",
    "\n",
    "\n",
    "model = Net().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(args, model, device, train_loader, optimizer, epoch)\n",
    "    test(args, model, device, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the trained model"
   ]
  },
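  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Create the output directory if it does not exist, then save the trained weights\n",
    "os.makedirs('output', exist_ok=True)\n",
    "torch.save(model.state_dict(), 'output/mnist.pth')"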
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export the trained model to ONNX\n",
    "To export the model, the PyTorch exporter runs the model once and saves the resulting traced graph to an ONNX file. We therefore need to provide an input for our MNIST model. At inference time the model expects a black-and-white 28 x 28 image, so we pass a dummy input of that shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained model from file (map_location lets weights trained on a GPU load on the CPU)\n",
    "trained_model = Net()\n",
    "trained_model.load_state_dict(torch.load('output/mnist.pth', map_location='cpu'))\n",
    "trained_model.eval()  # put the model in inference mode (disables dropout)\n",
    "\n",
    "# Export the trained model to ONNX\n",
    "dummy_input = torch.randn(1, 1, 28, 28)  # one black and white 28 x 28 picture will be the input to the model\n",
    "torch.onnx.export(trained_model, dummy_input, \"output/mnist.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "P.S. You can examine the graph of the mnist.onnx file with an ONNX viewer called [Netron](https://github.com/lutzroeder/Netron)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import the ONNX model to TensorFlow\n",
    "We will use `onnx_tf.backend.prepare` to import the ONNX model into TensorFlow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "from onnx_tf.backend import prepare\n",
    "\n",
    "# Load the ONNX file (named onnx_model so it does not shadow the PyTorch model above)\n",
    "onnx_model = onnx.load('output/mnist.onnx')\n",
    "\n",
    "# Import the ONNX model to TensorFlow\n",
    "tf_rep = prepare(onnx_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's explore the `tf_rep` object returned by `onnx_tf.backend.prepare`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input nodes to the model\n",
    "print('inputs:', tf_rep.inputs)\n",
    "\n",
    "# Output nodes from the model\n",
    "print('outputs:', tf_rep.outputs)\n",
    "\n",
    "# All nodes in the model\n",
    "print('tensor_dict:')\n",
    "print(tf_rep.tensor_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the model in TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from IPython.display import display\n",
    "from PIL import Image\n",
    "\n",
    "print('Image 1:')\n",
    "img = Image.open('assets/two.png').resize((28, 28)).convert('L')\n",
    "display(img)\n",
    "output = tf_rep.run(np.asarray(img, dtype=np.float32)[np.newaxis, np.newaxis, :, :])\n",
    "print('The digit is classified as ', np.argmax(output))\n",
    "\n",
    "print('Image 2:')\n",
    "img = Image.open('assets/three.png').resize((28, 28)).convert('L')\n",
    "display(img)\n",
    "output = tf_rep.run(np.asarray(img, dtype=np.float32)[np.newaxis, np.newaxis, :, :])\n",
    "print('The digit is classified as ', np.argmax(output))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the TensorFlow model into a file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_rep.export_graph('output/mnist.pb')"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
